/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/hlo_module_dce.h"

#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/types.h"

namespace xla {
namespace {

class HloModuleDceTest : public HloTestBase {
 protected:
  HloModuleDceTest() {}

  // Returns whether the given instruction exists in the given computation.
  bool HasInstruction(const HloComputation& computation,
                      const HloInstruction* instruction) {
    return std::find(computation.instructions().begin(),
                     computation.instructions().end(),
                     instruction) != computation.instructions().end();
  }

  // Returns whether the while instruction with name 'while_name' in
  // 'computation' passes through its tuple element at 'tuple_index' from
  // parameter to root instruction.
  bool WhileBodyHasPassThroughTupleElement(const HloComputation* computation,
                                           const string& while_name,
                                           const int64 tuple_index) {
    for (auto* instruction : computation->instructions()) {
      if (instruction->opcode() == HloOpcode::kWhile &&
          instruction->name() == while_name) {
        auto* while_body_comp = instruction->while_body();
        auto* while_body_param = while_body_comp->parameter_instruction(0);
        auto* while_body_root = while_body_comp->root_instruction();
        if (while_body_root->opcode() != HloOpcode::kTuple) {
          return false;
        }
        auto* operand = while_body_root->operand(tuple_index);
        if (operand->opcode() == HloOpcode::kGetTupleElement &&
            operand->tuple_index() == tuple_index &&
            operand->operand(0) == while_body_param) {
          return true;
        }
        return false;
      }
    }
    return false;
  }
};

// Verifies that HloModuleDCE leaves a while loop alone when its whole result
// tuple is live: here the while is the root of the entry computation, so the
// pass is expected to report no change and neither tuple element becomes
// pass-through.
TEST_F(HloModuleDceTest, WhileWithLiveOutputs) {
  const char* const hlo_text = R"(
  HloModule SimpleLoop
  SimpleLoop.body {
    loop_var.1 = (s32[], s32[3]{0}) parameter(0)
    get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0
    constant.1 = s32[] constant(1)
    add = s32[] add(get-tuple-element.1, constant.1)
    get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1
    multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2)
    ROOT tuple = (s32[], s32[3]{0}) tuple(add, multiply)
  }
  SimpleLoop.condition {
    loop_var.2 = (s32[], s32[3]{0}) parameter(0)
    get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
    constant.2 = s32[] constant(5)
    ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2)
  }
  ENTRY SimpleLoop {
    constant.3 = s32[] constant(0)
    constant.4 = s32[3]{0} constant({0, 1, 2})
    tuple.1 = (s32[], s32[3]{0}) tuple(constant.3, constant.4)
    ROOT while = (s32[], s32[3]{0}) while(tuple.1), condition=
      SimpleLoop.condition, body=SimpleLoop.body
  })";
  auto module = ParseHloString(hlo_text).ValueOrDie();

  // Reports whether the named while in the entry computation forwards the
  // given tuple element through its body. The entry computation is looked up
  // on every call.
  auto is_pass_through = [&](const char* while_name, const int64 tuple_index) {
    return WhileBodyHasPassThroughTupleElement(module->entry_computation(),
                                               while_name, tuple_index);
  };

  HloModuleDCE dce;
  const bool changed = dce.Run(module.get()).ValueOrDie();
  EXPECT_FALSE(changed);
  EXPECT_FALSE(is_pass_through("while", 0));
  EXPECT_FALSE(is_pass_through("while", 1));
}

// Tests that a while loop with one unused output is left unmodified when that
// output is used inside the while body by an instruction with side effects
// (rng). The entry root reads only tuple element {0}; element {1} is read in
// the body by both 'rng' and 'add.1'.
//
// 'add.1' is declared f32[] so that its shape agrees with its f32[] operands
// and with element {1} of the (s32[], f32[]) root tuple.
TEST_F(HloModuleDceTest, WhileWithUnusedSideEffectingTupleElement) {
  const char* const hlo_text = R"(
  HloModule SimpleLoop
  SimpleLoop.body {
    loop_var.1 = (s32[], f32[]) parameter(0)
    get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0
    constant.1 = s32[] constant(1)
    add = s32[] add(get-tuple-element.1, constant.1)
    get-tuple-element.2 = f32[] get-tuple-element(loop_var.1), index=1
    constant.2 = f32[] constant(1.0)
    rng = f32[] rng(constant.2, get-tuple-element.2), distribution=rng_uniform
    add.1 = f32[] add(get-tuple-element.2, constant.2)
    ROOT tuple = (s32[], f32[]) tuple(add, add.1)
  }
  SimpleLoop.condition {
    loop_var.2 = (s32[], f32[]) parameter(0)
    get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
    constant.3 = s32[] constant(5)
    ROOT less-than = pred[] less-than(get-tuple-element.3, constant.3)
  }
  ENTRY SimpleLoop {
    constant.4 = s32[] constant(0)
    constant.5 = f32[] constant(0.0)
    tuple.1 = (s32[], f32[]) tuple(constant.4, constant.5)
    while = (s32[], f32[]) while(tuple.1), condition=
      SimpleLoop.condition, body=SimpleLoop.body
    ROOT get-tuple-element.4 = s32[] get-tuple-element(while), index=0
  })";
  auto module = ParseHloString(hlo_text).ValueOrDie();

  // Reports whether the named while in the entry computation forwards the
  // given tuple element through its body. The entry computation is looked up
  // on every call.
  auto is_pass_through = [&](const char* while_name, const int64 tuple_index) {
    return WhileBodyHasPassThroughTupleElement(module->entry_computation(),
                                               while_name, tuple_index);
  };

  HloModuleDCE dce;
  const bool changed = dce.Run(module.get()).ValueOrDie();
  // The pass must report no change, and neither element may have been
  // rewritten into a pass-through.
  EXPECT_FALSE(changed);
  EXPECT_FALSE(is_pass_through("while", 0));
  EXPECT_FALSE(is_pass_through("while", 1));
}

// Verifies that a while loop whose tuple element {1} is dead (the entry root
// reads only element {0}, and the condition reads only element {0}) has its
// body rewritten so that element {1} passes straight through the body.
TEST_F(HloModuleDceTest, OneWhileWithDeadTupleElement) {
  const char* const hlo_text = R"(
  HloModule SimpleLoop
  SimpleLoop.body {
    loop_var.1 = (s32[], s32[3]{0}) parameter(0)
    get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0
    constant.1 = s32[] constant(1)
    add = s32[] add(get-tuple-element.1, constant.1)
    get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1
    multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2)
    ROOT tuple = (s32[], s32[3]{0}) tuple(add, multiply)
  }
  SimpleLoop.condition {
    loop_var.2 = (s32[], s32[3]{0}) parameter(0)
    get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
    constant.2 = s32[] constant(5)
    ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2)
  }
  ENTRY SimpleLoop {
    constant.3 = s32[] constant(0)
    constant.4 = s32[3]{0} constant({0, 1, 2})
    tuple.1 = (s32[], s32[3]{0}) tuple(constant.3, constant.4)
    while = (s32[], s32[3]{0}) while(tuple.1), condition=
      SimpleLoop.condition, body=SimpleLoop.body
    ROOT get-tuple-element.4 = s32[] get-tuple-element(while), index=0
  })";
  auto module = ParseHloString(hlo_text).ValueOrDie();

  // Reports whether the named while in the entry computation forwards the
  // given tuple element through its body. The entry computation is looked up
  // on every call.
  auto is_pass_through = [&](const char* while_name, const int64 tuple_index) {
    return WhileBodyHasPassThroughTupleElement(module->entry_computation(),
                                               while_name, tuple_index);
  };

  HloModuleDCE dce;
  // Before the pass: element {1} is computed by 'multiply', not passed through.
  EXPECT_FALSE(is_pass_through("while", 1));

  const bool changed = dce.Run(module.get()).ValueOrDie();
  EXPECT_TRUE(changed);

  // After the pass: the live element {0} is untouched, while the dead
  // element {1} is now pass-through.
  EXPECT_FALSE(is_pass_through("while", 0));
  EXPECT_TRUE(is_pass_through("while", 1));
}

// Tests that a tuple element {1} used by the condition computation (which
// appears dead in while.body{1} and at while.result{1}) propagates liveness of
// this tuple element to while.body{1} and to while.result{1}. The pass is
// therefore expected to leave the module unchanged.
TEST_F(HloModuleDceTest, OneWhileWithTupleElementUsedByCond) {
  const char* const hlo_text = R"(
  HloModule SimpleLoop
  SimpleLoop.body {
    loop_var.1 = (s32[], s32[]) parameter(0)
    get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0
    constant.1 = s32[] constant(1)
    add = s32[] add(get-tuple-element.1, constant.1)
    get-tuple-element.2 = s32[] get-tuple-element(loop_var.1), index=1
    multiply = s32[] multiply(get-tuple-element.2, get-tuple-element.2)
    ROOT tuple = (s32[], s32[]) tuple(add, multiply)
  }
  SimpleLoop.condition {
    loop_var.2 = (s32[], s32[]) parameter(0)
    get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=1
    constant.2 = s32[] constant(5)
    ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2)
  }
  ENTRY SimpleLoop {
    constant.3 = s32[] constant(0)
    constant.4 = s32[] constant(0)
    tuple.1 = (s32[], s32[]) tuple(constant.3, constant.4)
    while = (s32[], s32[]) while(tuple.1), condition=
      SimpleLoop.condition, body=SimpleLoop.body
    ROOT get-tuple-element.4 = s32[] get-tuple-element(while), index=0
  })";
  auto module = ParseHloString(hlo_text).ValueOrDie();

  // Reports whether the named while in the entry computation forwards the
  // given tuple element through its body. The entry computation is looked up
  // on every call.
  auto is_pass_through = [&](const char* while_name, const int64 tuple_index) {
    return WhileBodyHasPassThroughTupleElement(module->entry_computation(),
                                               while_name, tuple_index);
  };

  HloModuleDCE dce;
  // While tuple element {1} should not be pass-through before ModuleDCE.
  EXPECT_FALSE(is_pass_through("while", 1));

  const bool changed = dce.Run(module.get()).ValueOrDie();
  EXPECT_FALSE(changed);

  EXPECT_FALSE(is_pass_through("while", 0));
  // While tuple element {1} should still not be pass-through after ModuleDCE,
  // because the condition computation reads it.
  EXPECT_FALSE(is_pass_through("while", 1));
}

// Verifies that HloModuleDCE handles a dead tuple element at index {1} in two
// dependent while loops: while.2 consumes element {0} of while.1, and the
// entry root reads only element {0} of while.2. Both bodies are expected to be
// rewritten so that element {1} passes through.
TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElement) {
  const char* const hlo_text = R"(
  HloModule SimpleLoop
  SimpleLoop.body0 {
    loop_var.1 = (s32[], s32[3]{0}) parameter(0)
    get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0
    constant.1 = s32[] constant(1)
    add = s32[] add(get-tuple-element.1, constant.1)
    get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1
    multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2)
    ROOT tuple = (s32[], s32[3]{0}) tuple(add, multiply)
  }
  SimpleLoop.condition0 {
    loop_var.2 = (s32[], s32[3]{0}) parameter(0)
    get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
    constant.2 = s32[] constant(5)
    ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2)
  }
  SimpleLoop.body1 {
    loop_var.3 = (s32[], s32[3]{0}) parameter(0)
    get-tuple-element.4 = s32[] get-tuple-element(loop_var.3), index=0
    constant.3 = s32[] constant(1)
    add.1 = s32[] add(get-tuple-element.4, constant.3)
    get-tuple-element.5 = s32[3]{0} get-tuple-element(loop_var.3), index=1
    multiply.1 = s32[3]{0} multiply(get-tuple-element.5, get-tuple-element.5)
    ROOT tuple.1 = (s32[], s32[3]{0}) tuple(add.1, multiply.1)
  }
  SimpleLoop.condition1 {
    loop_var.4 = (s32[], s32[3]{0}) parameter(0)
    get-tuple-element.6 = s32[] get-tuple-element(loop_var.4), index=0
    constant.4 = s32[] constant(5)
    ROOT less-than.1 = pred[] less-than(get-tuple-element.6, constant.4)
  }
  ENTRY SimpleLoop {
    constant.5 = s32[] constant(0)
    constant.6 = s32[3]{0} constant({0, 1, 2})
    tuple.2 = (s32[], s32[3]{0}) tuple(constant.5, constant.6)
    while.1 = (s32[], s32[3]{0}) while(tuple.2), condition=
      SimpleLoop.condition0, body=SimpleLoop.body0
    get-tuple-element.7 = s32[] get-tuple-element(while.1), index=0
    tuple.3 = (s32[], s32[3]{0}) tuple(get-tuple-element.7, constant.6)
    while.2 = (s32[], s32[3]{0}) while(tuple.3), condition=
      SimpleLoop.condition1, body=SimpleLoop.body1
    ROOT get-tuple-element.8 = s32[] get-tuple-element(while.2), index=0
  })";
  auto module = ParseHloString(hlo_text).ValueOrDie();

  // Reports whether the named while in the entry computation forwards the
  // given tuple element through its body. The entry computation is looked up
  // on every call.
  auto is_pass_through = [&](const char* while_name, const int64 tuple_index) {
    return WhileBodyHasPassThroughTupleElement(module->entry_computation(),
                                               while_name, tuple_index);
  };

  HloModuleDCE dce;
  // Before HloModuleDCE, neither while.1 nor while.2 passes element {1}
  // through.
  EXPECT_FALSE(is_pass_through("while.1", 1));
  EXPECT_FALSE(is_pass_through("while.2", 1));

  const bool changed = dce.Run(module.get()).ValueOrDie();
  EXPECT_TRUE(changed);

  // After HloModuleDCE, both whiles pass the unused element {1} through,
  // while the live element {0} is still computed in each body.
  EXPECT_FALSE(is_pass_through("while.1", 0));
  EXPECT_TRUE(is_pass_through("while.1", 1));
  EXPECT_FALSE(is_pass_through("while.2", 0));
  EXPECT_TRUE(is_pass_through("while.2", 1));
}

// Verifies that HloModuleDCE handles dead tuple elements that sit at different
// indices in two dependent while loops: while.1 carries its dead element at
// {0} and while.2 carries its dead element at {1}. while.2 consumes element
// {1} of while.1, and the entry root reads only element {0} of while.2.
TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElementSwizzled) {
  const char* const hlo_text = R"(
  HloModule SimpleLoop
  SimpleLoop.body0 {
    loop_var.1 = (s32[3]{0}, s32[]) parameter(0)
    get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=1
    constant.1 = s32[] constant(1)
    add = s32[] add(get-tuple-element.1, constant.1)
    get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=0
    multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2)
    ROOT tuple = (s32[3]{0}, s32[]) tuple(multiply, add)
  }
  SimpleLoop.condition0 {
    loop_var.2 = (s32[3]{0}, s32[]) parameter(0)
    get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=1
    constant.2 = s32[] constant(5)
    ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2)
  }
  SimpleLoop.body1 {
    loop_var.3 = (s32[], s32[3]{0}) parameter(0)
    get-tuple-element.4 = s32[] get-tuple-element(loop_var.3), index=0
    constant.3 = s32[] constant(1)
    add.1 = s32[] add(get-tuple-element.4, constant.3)
    get-tuple-element.5 = s32[3]{0} get-tuple-element(loop_var.3), index=1
    multiply.1 = s32[3]{0} multiply(get-tuple-element.5, get-tuple-element.5)
    ROOT tuple.1 = (s32[], s32[3]{0}) tuple(add.1, multiply.1)
  }
  SimpleLoop.condition1 {
    loop_var.4 = (s32[], s32[3]{0}) parameter(0)
    get-tuple-element.6 = s32[] get-tuple-element(loop_var.4), index=0
    constant.4 = s32[] constant(5)
    ROOT less-than.1 = pred[] less-than(get-tuple-element.6, constant.4)
  }
  ENTRY SimpleLoop {
    constant.5 = s32[] constant(0)
    constant.6 = s32[3]{0} constant({0, 1, 2})
    tuple.2 = (s32[3]{0}, s32[]) tuple(constant.6, constant.5)
    while.1 = (s32[3]{0}, s32[]) while(tuple.2), condition=
      SimpleLoop.condition0, body=SimpleLoop.body0
    get-tuple-element.7 = s32[] get-tuple-element(while.1), index=1
    tuple.3 = (s32[], s32[3]{0}) tuple(get-tuple-element.7, constant.6)
    while.2 = (s32[], s32[3]{0}) while(tuple.3), condition=
      SimpleLoop.condition1, body=SimpleLoop.body1
    ROOT get-tuple-element.8 = s32[] get-tuple-element(while.2), index=0
  })";
  auto module = ParseHloString(hlo_text).ValueOrDie();

  // Reports whether the named while in the entry computation forwards the
  // given tuple element through its body. The entry computation is looked up
  // on every call.
  auto is_pass_through = [&](const char* while_name, const int64 tuple_index) {
    return WhileBodyHasPassThroughTupleElement(module->entry_computation(),
                                               while_name, tuple_index);
  };

  HloModuleDCE dce;
  // Before HloModuleDCE while.1{0} and while.2{1} should not be pass-thru.
  EXPECT_FALSE(is_pass_through("while.1", 0));
  EXPECT_FALSE(is_pass_through("while.2", 1));

  const bool changed = dce.Run(module.get()).ValueOrDie();
  EXPECT_TRUE(changed);

  // After HloModuleDCE while.1{0} and while.2{1} should be pass-thru elements,
  // while the live elements while.1{1} and while.2{0} should not be.
  EXPECT_FALSE(is_pass_through("while.1", 1));
  EXPECT_TRUE(is_pass_through("while.1", 0));
  EXPECT_FALSE(is_pass_through("while.2", 0));
  EXPECT_TRUE(is_pass_through("while.2", 1));
}

}  // namespace
}  // namespace xla
